from typing import Any, Optional, TypeVar
from ... import Tensor
from ..modules import Module


class SpectralNorm:
    """Type stub for ``SpectralNorm``.

    Only signatures are declared here; every body and default value is
    ``...``, so the runtime behavior lives in the implementation module,
    not in this file.
    """

    # Configuration attributes (types only; defaults are elided in the stub).
    name: str = ...
    dim: int = ...
    n_power_iterations: int = ...
    eps: float = ...

    def __init__(
        self,
        name: str = ...,
        n_power_iterations: int = ...,
        dim: int = ...,
        eps: float = ...,
    ) -> None:
        """Declare the constructor; note the positional order is
        ``name, n_power_iterations, dim, eps`` (``dim`` before ``eps``)."""
        ...

    def reshape_weight_to_matrix(self, weight: Tensor) -> Tensor:
        """Map a weight ``Tensor`` to a ``Tensor`` (presumably a 2-D matrix
        view, per the method name -- confirm against the implementation)."""
        ...

    def compute_weight(self, module: Module, do_power_iteration: bool) -> Tensor:
        """Produce a ``Tensor`` for ``module``; ``do_power_iteration`` is a
        boolean switch whose exact effect is defined by the implementation."""
        ...

    def remove(self, module: Module) -> None:
        """Operate on ``module`` for removal; returns nothing."""
        ...

    def __call__(self, module: Module, inputs: Any) -> None:
        """Make instances callable as ``(module, inputs) -> None``.

        NOTE(review): this shape matches a forward pre-hook; presumably the
        instance is registered as one -- confirm in the implementation.
        """
        ...

    @staticmethod
    def apply(
        module: Module,
        name: str,
        n_power_iterations: int,
        dim: int,
        eps: float,
    ) -> "SpectralNorm":
        """Static factory taking a module plus all configuration values
        (no defaults) and returning a ``SpectralNorm`` instance."""
        ...


# Generic module type: lets the functions below declare that they return the
# same ``Module`` subtype they were given.
T_module = TypeVar("T_module", bound=Module)


def spectral_norm(
    module: T_module,
    name: str = ...,
    n_power_iterations: int = ...,
    eps: float = ...,
    dim: Optional[int] = ...,
) -> T_module:
    """Type stub: accept a module and return a value of the same module type.

    Note the parameter order differs from ``SpectralNorm.__init__``: here
    ``eps`` precedes ``dim``, and ``dim`` may be ``None``. Defaults are
    elided (``...``) in this stub.
    """
    ...


def remove_spectral_norm(
    module: T_module,
    name: str = ...,
) -> T_module:
    """Type stub: take a module (and optional ``name``) and return a value of
    the same module type. The default for ``name`` is elided in this stub."""
    ...
